/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package modelendpoint

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// TestLoadLoRA_URLConstruction checks the request path that loadLoRA produces
// when the endpoint address has no trailing slash, has a trailing slash, or
// carries a path prefix.
func TestLoadLoRA_URLConstruction(t *testing.T) {
	tests := []struct {
		name string
		// addressSuffix is appended to the test server URL to form the
		// address handed to loadLoRA.
		addressSuffix   string
		expectedURLPath string
	}{
		{
			name:            "address without trailing slash",
			addressSuffix:   "",
			expectedURLPath: "/v1/loras",
		},
		{
			name:            "address with trailing slash",
			addressSuffix:   "/",
			expectedURLPath: "/v1/loras",
		},
		{
			name:            "address with path",
			addressSuffix:   "/api",
			expectedURLPath: "/api/v1/loras",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test server that records the path of the incoming request.
			var capturedPath string
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				capturedPath = r.URL.Path
				w.WriteHeader(http.StatusOK)
			}))
			defer server.Close()

			client := NewClient()
			ctx := context.Background()

			// Fail fast on a request error: otherwise capturedPath stays empty
			// and the mismatch below hides the real cause.
			if err := client.loadLoRA(ctx, server.URL+tt.addressSuffix, "test-model", "s3://bucket/model"); err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if capturedPath != tt.expectedURLPath {
				t.Errorf("expected URL path %s, got %s", tt.expectedURLPath, capturedPath)
			}
		})
	}
}

// TestLoadLoRA_RequestBody checks that loadLoRA sends a JSON request whose
// "lora_name" field and nested "source"."uri" field carry the model name and
// source URI it was given, with a Content-Type of application/json.
func TestLoadLoRA_RequestBody(t *testing.T) {
	tests := []struct {
		name              string
		modelName         string
		sourceURI         string
		expectedLoraName  string
		expectedSourceURI string
	}{
		{
			name:              "basic lora load",
			modelName:         "my-lora",
			sourceURI:         "s3://bucket/model",
			expectedLoraName:  "my-lora",
			expectedSourceURI: "s3://bucket/model",
		},
		{
			name:              "huggingface lora",
			modelName:         "hf-lora",
			sourceURI:         "hf://org/model@v1.0",
			expectedLoraName:  "hf-lora",
			expectedSourceURI: "hf://org/model@v1.0",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test server that decodes and records the JSON request body.
			var capturedBody map[string]interface{}
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				body, err := io.ReadAll(r.Body)
				if err != nil {
					t.Errorf("reading request body: %v", err)
					w.WriteHeader(http.StatusBadRequest)
					return
				}
				// Report malformed JSON here so it is not mistaken for a
				// missing-field failure in the assertions below.
				if err := json.Unmarshal(body, &capturedBody); err != nil {
					t.Errorf("decoding request body %q: %v", body, err)
					w.WriteHeader(http.StatusBadRequest)
					return
				}

				if r.Header.Get("Content-Type") != "application/json" {
					t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
				}

				w.WriteHeader(http.StatusOK)
			}))
			defer server.Close()

			client := NewClient()
			ctx := context.Background()

			err := client.loadLoRA(ctx, server.URL, tt.modelName, tt.sourceURI)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			// Verify the top-level adapter name.
			if capturedBody["lora_name"] != tt.expectedLoraName {
				t.Errorf("expected lora_name %s, got %v", tt.expectedLoraName, capturedBody["lora_name"])
			}

			// Verify the nested source object and its URI.
			source, ok := capturedBody["source"].(map[string]interface{})
			if !ok {
				t.Fatal("expected source to be a map")
			}

			if source["uri"] != tt.expectedSourceURI {
				t.Errorf("expected source URI %s, got %v", tt.expectedSourceURI, source["uri"])
			}
		})
	}
}

// TestLoadLoRA_ResponseHandling checks how loadLoRA maps HTTP status codes to
// its returned error: 200 and 201 must yield nil, while 400, 404 and 500 must
// yield an error whose text mentions the status code.
func TestLoadLoRA_ResponseHandling(t *testing.T) {
	cases := []struct {
		name          string
		statusCode    int
		responseBody  string
		expectError   bool
		errorContains string
	}{
		{name: "success - 200 OK", statusCode: http.StatusOK},
		{name: "success - 201 Created", statusCode: http.StatusCreated},
		{
			name:          "failure - 400 Bad Request",
			statusCode:    http.StatusBadRequest,
			responseBody:  "Invalid LoRA",
			expectError:   true,
			errorContains: "400",
		},
		{
			name:          "failure - 404 Not Found",
			statusCode:    http.StatusNotFound,
			responseBody:  "Endpoint not found",
			expectError:   true,
			errorContains: "404",
		},
		{
			name:          "failure - 500 Internal Server Error",
			statusCode:    http.StatusInternalServerError,
			responseBody:  "Server error",
			expectError:   true,
			errorContains: "500",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Server that replies with the configured status and optional body.
			respond := func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tc.statusCode)
				if tc.responseBody != "" {
					_, _ = w.Write([]byte(tc.responseBody))
				}
			}
			server := httptest.NewServer(http.HandlerFunc(respond))
			defer server.Close()

			err := NewClient().loadLoRA(context.Background(), server.URL, "test-model", "s3://bucket/model")

			switch {
			case !tc.expectError:
				if err != nil {
					t.Errorf("expected no error but got: %v", err)
				}
			case err == nil:
				t.Error("expected error but got none")
			case tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains):
				t.Errorf("expected error to contain %q, got %v", tc.errorContains, err)
			}
		})
	}
}

// TestUnloadLoRA_URLConstruction checks that unloadLoRA issues a DELETE
// request to /v1/loras/<modelName> for several shapes of model name.
func TestUnloadLoRA_URLConstruction(t *testing.T) {
	cases := []struct {
		name            string
		modelName       string
		expectedURLPath string
	}{
		{name: "simple model name", modelName: "my-lora", expectedURLPath: "/v1/loras/my-lora"},
		{name: "model name with special chars", modelName: "my-lora-v1.0", expectedURLPath: "/v1/loras/my-lora-v1.0"},
		// NOTE(review): r.URL.Path is the decoded path, so this case yields
		// the same captured value whether or not the client percent-encodes
		// the slash.
		{name: "model name with slashes (URL encoded)", modelName: "org/model", expectedURLPath: "/v1/loras/org/model"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Record the method and path of the request the server receives.
			var gotMethod, gotPath string
			record := func(w http.ResponseWriter, r *http.Request) {
				gotPath = r.URL.Path
				gotMethod = r.Method
				w.WriteHeader(http.StatusOK)
			}
			server := httptest.NewServer(http.HandlerFunc(record))
			defer server.Close()

			if err := NewClient().unloadLoRA(context.Background(), server.URL, tc.modelName); err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if gotMethod != http.MethodDelete {
				t.Errorf("expected DELETE method, got %s", gotMethod)
			}

			if gotPath != tc.expectedURLPath {
				t.Errorf("expected URL path %s, got %s", tc.expectedURLPath, gotPath)
			}
		})
	}
}

// TestUnloadLoRA_ResponseHandling checks how unloadLoRA maps HTTP status codes
// to its returned error: 200 and 204 must yield nil, while 404 and 500 must
// yield an error whose text mentions the status code.
func TestUnloadLoRA_ResponseHandling(t *testing.T) {
	cases := []struct {
		name          string
		statusCode    int
		responseBody  string
		expectError   bool
		errorContains string
	}{
		{name: "success - 200 OK", statusCode: http.StatusOK},
		{name: "success - 204 No Content", statusCode: http.StatusNoContent},
		{
			name:          "failure - 404 Not Found",
			statusCode:    http.StatusNotFound,
			responseBody:  "LoRA not found",
			expectError:   true,
			errorContains: "404",
		},
		{
			name:          "failure - 500 Internal Server Error",
			statusCode:    http.StatusInternalServerError,
			responseBody:  "Failed to unload",
			expectError:   true,
			errorContains: "500",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Server that replies with the configured status and optional body.
			respond := func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tc.statusCode)
				if tc.responseBody != "" {
					_, _ = w.Write([]byte(tc.responseBody))
				}
			}
			server := httptest.NewServer(http.HandlerFunc(respond))
			defer server.Close()

			err := NewClient().unloadLoRA(context.Background(), server.URL, "test-model")

			switch {
			case !tc.expectError:
				if err != nil {
					t.Errorf("expected no error but got: %v", err)
				}
			case err == nil:
				t.Error("expected error but got none")
			case tc.errorContains != "" && !strings.Contains(err.Error(), tc.errorContains):
				t.Errorf("expected error to contain %q, got %v", tc.errorContains, err)
			}
		})
	}
}
